{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[32m\u001b[1m    Status\u001b[22m\u001b[39m `/mnt/E4E0A9C0E0A998F6/github/ReinforcementLearningAnIntroduction.jl/notebooks/Project.toml`\n",
      " \u001b[90m [31c24e10]\u001b[39m\u001b[37m Distributions v0.22.4\u001b[39m\n",
      " \u001b[90m [91a5bcdd]\u001b[39m\u001b[37m Plots v0.29.1\u001b[39m\n",
      " \u001b[90m [02c1da58]\u001b[39m\u001b[37m ReinforcementLearningAnIntroduction v0.2.0 [`..`]\u001b[39m\n",
      " \u001b[90m [2913bbd2]\u001b[39m\u001b[37m StatsBase v0.32.1\u001b[39m\n",
      " \u001b[90m [f3b207a7]\u001b[39m\u001b[37m StatsPlots v0.12.0\u001b[39m\n",
      " \u001b[90m [2f01184e]\u001b[39m\u001b[37m SparseArrays \u001b[39m\n"
     ]
    }
   ],
   "source": [
    "]st"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "┌ Info: Precompiling ReinforcementLearningAnIntroduction [02c1da58-b9a1-11e8-0212-f9611b8fe936]\n",
      "└ @ Base loading.jl:1273\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "___\n",
       "___\n",
       "___\n",
       "isdone = [false], winner = [nothing]\n"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "using ReinforcementLearningAnIntroduction\n",
    "using ReinforcementLearningAnIntroduction.TicTacToe\n",
    "env = TicTacToeEnv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "X"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_current_player(env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5478, 10)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "observation_space, action_space = get_observation_space(env), get_action_space(env)\n",
    "nstates, nactions = length(observation_space), length(action_space)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you are curious why there are `5478` states, you may see the discussions [here](https://math.stackexchange.com/questions/485752/tictactoe-state-space-choose-calculation/485852)"
   ]
  },
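  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The count can also be checked by brute force. The following is a minimal sketch that does not use the package: it assumes a board is a 9-tuple with `0` for an empty cell and `1`/`2` for the two players, and that play stops as soon as someone wins or the board is full. The names `LINES`, `has_winner` and `count_states` are made up for this illustration.\n",
    "\n",
    "```julia\n",
    "# All eight winning lines of a 3x3 board, using cell indices 1..9.\n",
    "const LINES = [(1, 2, 3), (4, 5, 6), (7, 8, 9), (1, 4, 7), (2, 5, 8), (3, 6, 9), (1, 5, 9), (3, 5, 7)]\n",
    "\n",
    "# True if any line is filled by a single player.\n",
    "has_winner(b) = any(b[i] != 0 && b[i] == b[j] == b[k] for (i, j, k) in LINES)\n",
    "\n",
    "function count_states()\n",
    "    seen = Set{NTuple{9,Int}}()\n",
    "    function visit(board, player)\n",
    "        board in seen && return\n",
    "        push!(seen, board)\n",
    "        # No further moves once someone has won or the board is full.\n",
    "        (has_winner(board) || all(!=(0), board)) && return\n",
    "        for i in 1:9\n",
    "            board[i] == 0 || continue\n",
    "            next = ntuple(j -> j == i ? player : board[j], 9)\n",
    "            visit(next, 3 - player)  # switch between player 1 and player 2\n",
    "        end\n",
    "    end\n",
    "    visit(ntuple(_ -> 0, 9), 1)  # start from the empty board, player 1 moves first\n",
    "    length(seen)\n",
    "end\n",
    "\n",
    "count_states()\n",
    "```\n",
    "\n",
    "Under these assumptions the search visits `5478` distinct boards, which agrees with `nstates` above."
   ]
  },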
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(reward = 0.0, terminal = false, state = 4151, legal_actions_mask = Bool[1, 1, 1, 1, 1, 1, 1, 1, 1, 0])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "observe(env)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we'll use the Monte Carlo based method to estimate the value of each state for each player. Think about this, if we have the precise estimation of each state after taking some specific observation according to current observation, then we can just choose the action which leads to the maximum estimation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So we can create a table for each player first. By default we can set the estimations of all the states to `0.0`. Usually it won't be a problem, but here we can initialize it with a better starting point. For each state, we can check that if the state is a final state or not and set the initial estimation accordingly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "init_table (generic function with 1 method)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function init_table(role)\n",
    "    table = zeros(nstates)\n",
    "    for i in 1:nstates\n",
    "        s = TicTacToe.ID2STATE[i]\n",
    "        isdone, winner = TicTacToe.STATES_INFO[s]\n",
    "        if isdone\n",
    "            if winner === nothing\n",
    "                table[i] = 0.5\n",
    "            elseif winner === role\n",
    "                table[i] = 1.\n",
    "            else\n",
    "                table[i] = 0.\n",
    "            end\n",
    "        else\n",
    "            table[i] = 0.5\n",
    "        end\n",
    "    end\n",
    "    table\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we wrap the table in a `TabularApproximator`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "V1 = TabularApproximator(init_table(TicTacToe.offensive));\n",
    "V2 = TabularApproximator(init_table(TicTacToe.defensive));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And wrap it around a MonteCarloLearner."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MonteCarloLearner{ReinforcementLearningAnIntroduction.EveryVisit,TabularApproximator{1,Array{Float64,1}},CachedSampleAvg{Float64},ReinforcementLearningAnIntroduction.NoSampling}(TabularApproximator{1,Array{Float64,1}}([0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 0.0, 0.5, 0.5, 0.5  …  0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]), 1.0, 0.1, CachedSampleAvg{Float64}(Dict{Float64,SampleAvg}()))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learner_1 = MonteCarloLearner(;approximator=V1, α=0.1, kind=EVERY_VISIT)\n",
    "learner_2 = MonteCarloLearner(;approximator=V2, α=0.1, kind=EVERY_VISIT)"
   ]
  },
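  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what an every-visit Monte Carlo update with a constant step size does, here is a package-independent sketch. It assumes that `α` acts as the step size and that returns are undiscounted (`γ = 1.0`). The function `mc_update!` and the toy episode are hypothetical; the package's learner may compute its estimates differently.\n",
    "\n",
    "```julia\n",
    "# Hypothetical helper, not part of the package.\n",
    "# `table` holds one value per state id, `states` lists the visited state ids,\n",
    "# and `rewards[t]` is the reward received after leaving `states[t]`.\n",
    "function mc_update!(table, states, rewards; α=0.1, γ=1.0)\n",
    "    G = 0.0\n",
    "    # Walk the episode backwards so the return G accumulates incrementally.\n",
    "    for t in length(states):-1:1\n",
    "        G = rewards[t] + γ * G\n",
    "        # Every-visit: each occurrence of a state is moved towards its return.\n",
    "        table[states[t]] += α * (G - table[states[t]])\n",
    "    end\n",
    "    table\n",
    "end\n",
    "\n",
    "table = fill(0.5, 4)                            # four toy states, all starting at 0.5\n",
    "mc_update!(table, [1, 2, 3], [0.0, 0.0, 1.0])   # one episode that ends with reward 1\n",
    "```\n",
    "\n",
    "In this toy episode the three visited states move a fraction `α` of the way from `0.5` towards the return `1.0`, while the unvisited fourth state keeps its initial value."
   ]
  },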
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then the learner is assemble into a policy.\n",
    "\n",
    "A policy is a mapping from states to actions. Considering that we already have the estimations of states, a simple policy would be checking the estimation of the following up states and select one action which will result to the best state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "create_mapping (generic function with 1 method)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function create_mapping(role)\n",
    "    (obs, learner) -> begin\n",
    "        mask = get_legal_actions_mask(obs)\n",
    "        [\n",
    "            mask[a] ? learner(StateOverriddenObs(obs=obs, state=TicTacToe.get_next_state_id(get_state(obs), role, a))) : 0.  # a dummy value     \n",
    "            for a in action_space\n",
    "        ]\n",
    "    end\n",
    "end"
   ]
  },
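  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The mapping returns one value per action and fills in a dummy `0.` for illegal ones, so the selection step has to ignore those entries. The following is a minimal, package-independent sketch of an ϵ-greedy choice restricted to legal actions. The function `epsilon_greedy` and its arguments are hypothetical and only meant to show the idea, not how the package's explorer is implemented.\n",
    "\n",
    "```julia\n",
    "using Random\n",
    "\n",
    "# Hypothetical helper, not part of the package.\n",
    "# `vals[a]` is the estimated value of the state reached by action `a`,\n",
    "# `mask[a]` tells whether action `a` is legal.\n",
    "function epsilon_greedy(vals, mask; ϵ=0.01, rng=Random.default_rng())\n",
    "    legal = findall(mask)\n",
    "    # With probability ϵ explore: pick a legal action uniformly at random.\n",
    "    if rand(rng) < ϵ\n",
    "        return rand(rng, legal)\n",
    "    end\n",
    "    # Otherwise exploit: take the best value among the legal actions only.\n",
    "    legal[argmax(vals[legal])]\n",
    "end\n",
    "\n",
    "vals = [0.2, 0.9, 0.0, 0.7]\n",
    "mask = [true, false, true, true]\n",
    "epsilon_greedy(vals, mask; ϵ=0.0)  # action 2 has the highest value but is illegal\n",
    "```\n",
    "\n",
    "With `ϵ = 0.0` the call above is purely greedy and returns action `4`, the best of the legal actions."
   ]
  },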
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "ϵ = 0.01\n",
    "\n",
    "π_1 = VBasedPolicy(\n",
    "    learner = learner_1,\n",
    "    mapping = create_mapping(TicTacToe.offensive),\n",
    "    explorer = EpsilonGreedyExplorer(ϵ),\n",
    "    )\n",
    "\n",
    "π_2 = VBasedPolicy(\n",
    "    learner = learner_2,\n",
    "    mapping = create_mapping(TicTacToe.defensive),\n",
    "    explorer = EpsilonGreedyExplorer(ϵ),\n",
    "    );\n",
    "\n",
    "agent_1 = Agent(\n",
    "    policy = π_1,\n",
    "    trajectory = EpisodicCompactSARTSATrajectory(),\n",
    "    role=TicTacToe.offensive\n",
    "    );\n",
    "\n",
    "agent_2 = Agent(\n",
    "    policy = π_2,\n",
    "    trajectory = EpisodicCompactSARTSATrajectory(),\n",
    "    role=TicTacToe.defensive\n",
    "    );\n",
    "\n",
    "reset!(env)\n",
    "\n",
    "agents = (agent_1, agent_2);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:42\u001b[39m8:41\u001b[39m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "2-element Array{EmptyHook,1}:\n",
       " EmptyHook()\n",
       " EmptyHook()"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "run((agent_1, agent_2), env, StopAfterEpisode(1000000))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent_1.policy.explorer.ϵ_stable = 0.0\n",
    "agent_2.policy.explorer.ϵ_stable = 0.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now it's your turn to play this game!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "play (generic function with 1 method)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function read_action_from_stdin()\n",
    "    print(\"Your input:\")\n",
    "    input = parse(Int, readline())\n",
    "    !in(input, 1:9) && error(\"invalid input!\")\n",
    "    input\n",
    "end\n",
    "\n",
    "function play()\n",
    "    env = TicTacToeEnv()\n",
    "    println(\"\"\"You play first!\n",
    "    1 4 7\n",
    "    2 5 8\n",
    "    3 6 9\"\"\")\n",
    "    while true\n",
    "        action = read_action_from_stdin()\n",
    "        env(action)\n",
    "        println(env)\n",
    "        obs = observe(env, TicTacToe.offensive)\n",
    "        if get_terminal(obs)\n",
    "            if get_reward(obs) == 0.5\n",
    "                println(\"Tie!\")\n",
    "            elseif get_reward(obs) == 1.0 \n",
    "                println(\"You win!\")\n",
    "            else\n",
    "                println(\"Invalid input!\")\n",
    "            end\n",
    "            break\n",
    "        end\n",
    "\n",
    "        env(agent_2(PRE_ACT_STAGE, observe(env)))\n",
    "        println(env)\n",
    "        obs = observe(env, TicTacToe.defensive)\n",
    "        if get_terminal(obs)\n",
    "            if get_reward(obs) == 0.5\n",
    "                println(\"Tie!\")\n",
    "            elseif get_reward(obs) == 1.0 \n",
    "                println(\"Your lose!\")\n",
    "            else\n",
    "                println(\"You win!\")\n",
    "            end\n",
    "            break\n",
    "        end\n",
    "    end\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "You play first!\n",
      "1 4 7\n",
      "2 5 8\n",
      "3 6 9\n",
      "Your input:stdin> 5\n",
      "___\n",
      "_X_\n",
      "___\n",
      "isdone = [false], winner = [nothing]\n",
      "\n",
      "___\n",
      "_X_\n",
      "O__\n",
      "isdone = [false], winner = [nothing]\n",
      "\n",
      "Your input:stdin> 6\n",
      "___\n",
      "_X_\n",
      "OX_\n",
      "isdone = [false], winner = [nothing]\n",
      "\n",
      "_O_\n",
      "_X_\n",
      "OX_\n",
      "isdone = [false], winner = [nothing]\n",
      "\n",
      "Your input:stdin> 8\n",
      "_O_\n",
      "_XX\n",
      "OX_\n",
      "isdone = [false], winner = [nothing]\n",
      "\n",
      "_O_\n",
      "OXX\n",
      "OX_\n",
      "isdone = [false], winner = [nothing]\n",
      "\n",
      "Your input:stdin> 1\n",
      "XO_\n",
      "OXX\n",
      "OX_\n",
      "isdone = [false], winner = [nothing]\n",
      "\n",
      "XO_\n",
      "OXX\n",
      "OXO\n",
      "isdone = [false], winner = [nothing]\n",
      "\n",
      "Your input:stdin> 7\n",
      "XOX\n",
      "OXX\n",
      "OXO\n",
      "isdone = [true], winner = [nothing]\n",
      "\n",
      "Tie!\n"
     ]
    }
   ],
   "source": [
    "play()"
   ]
  }
 ],
 "metadata": {
  "@webio": {
   "lastCommId": null,
   "lastKernelId": null
  },
  "kernelspec": {
   "display_name": "Julia 1.3.0",
   "language": "julia",
   "name": "julia-1.3"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.3.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
